A gentle introduction to the SumProductTransform library

This introduction uses several unregistered libraries, namely ToyProblems.jl and SumProductTransform.jl, which depends on Unitary.jl. The easiest way to get started is to instantiate the environment in the example/ directory, which should contain everything you need, including Pluto.

The introduction starts with a classic Gaussian Mixture Model, continues with a simple Sum Product Network, and culminates in a Sum Product Transform Network.

Before diving in, we import the libraries and define a convenience function for plotting densities and data.

The plotting function shows the density of a fitted model with the training data plotted on top.

Let's create training samples from the Flower dataset with nine petals.

Initialize the data dimension d, the batch size for stochastic gradient descent, and the number of training steps.
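
To make the roles of the batch size and the number of steps concrete, here is a minimal Python sketch of minibatch stochastic gradient descent (the notebook's Julia code is not reproduced in this export). It fits only the mean of a unit-variance Gaussian. The function name `sgd_fit_mean`, the learning rate, and the values d = 2, batchsize = 100, nsteps = 500 are hypothetical choices for illustration, not the notebook's settings.

```python
import random

def sgd_fit_mean(data, batchsize, nsteps, lr=0.1, seed=0):
    """Fit the mean of a unit-variance Gaussian to `data` by minibatch SGD.

    Each step draws `batchsize` points without replacement and moves the
    mean against the gradient of the average negative log-likelihood,
    which for a unit-variance Gaussian is mean(mu - x) over the batch.
    """
    rng = random.Random(seed)
    d = len(data[0])                    # dimension of the data
    mu = [0.0] * d                      # start from the origin
    for _ in range(nsteps):
        batch = rng.sample(data, batchsize)
        grad = [sum(mu[j] - x[j] for x in batch) / batchsize for j in range(d)]
        mu = [mu[j] - lr * grad[j] for j in range(d)]
    return mu

# Hypothetical settings: d = 2, batchsize = 100, nsteps = 500.
rng = random.Random(1)
data = [(rng.gauss(1.0, 1.0), rng.gauss(-2.0, 1.0)) for _ in range(2000)]
mu = sgd_fit_mean(data, batchsize=100, nsteps=500)
print(mu)  # should land near the true mean (1, -2)
```

Each step touches only `batchsize` points, so its cost does not grow with the size of the dataset; the price is a noisy gradient, which is why many steps are needed.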

Gaussian Mixture Model

A GMM with 144 components
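
A Gaussian mixture has density p(x) = sum_k w_k N(x; mu_k, Sigma_k), and its logarithm is best evaluated with the log-sum-exp trick to avoid underflow. The following minimal Python sketch assumes diagonal covariances; the 12 x 12 grid of means with equal weights and unit standard deviations is a hypothetical initialization chosen only to reach 144 components, not the one used in the notebook.

```python
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(xs)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def diag_normal_logpdf(x, mu, sigma):
    """Log-density of a Gaussian with diagonal covariance diag(sigma**2)."""
    return sum(-0.5 * ((xi - mi) / si) ** 2 - math.log(si) - 0.5 * math.log(2 * math.pi)
               for xi, mi, si in zip(x, mu, sigma))

def gmm_logpdf(x, weights, mus, sigmas):
    """log sum_k w_k N(x; mu_k, diag(sigma_k**2)), computed in log-space."""
    return logsumexp([math.log(w) + diag_normal_logpdf(x, m, s)
                      for w, m, s in zip(weights, mus, sigmas)])

# Hypothetical 144-component model: means on a 12 x 12 grid over [-4, 4]^2,
# equal weights and unit standard deviations.
grid = [-4 + 8 * i / 11 for i in range(12)]
mus = [(a, b) for a in grid for b in grid]
weights = [1 / len(mus)] * len(mus)
sigmas = [(1.0, 1.0)] * len(mus)
print(len(mus), gmm_logpdf((0.0, 0.0), weights, mus, sigmas))
```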

[Figure: density of the fitted GMM with the training data on top; color scale 0 to 0.12]

Sum Product Network
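
A sum-product network combines two kinds of nodes: a sum node is a weighted mixture of children defined over the same variables, and a product node multiplies children defined over disjoint sets of variables, so their log-densities add. The following Python sketch illustrates these semantics only; the class names `Leaf`, `Product`, `Sum` and all parameters are hypothetical and are not SumProductTransform.jl's API. It builds a two-component network over a 2D input and checks numerically that the density integrates to about one.

```python
import math
from dataclasses import dataclass

def logsumexp(xs):
    """Numerically stable log(sum(exp(xs)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

@dataclass
class Leaf:
    """Univariate Normal on a single coordinate `dim` of x."""
    dim: int
    mu: float
    sigma: float

    def logpdf(self, x):
        z = (x[self.dim] - self.mu) / self.sigma
        return -0.5 * z * z - math.log(self.sigma) - 0.5 * math.log(2 * math.pi)

@dataclass
class Product:
    """Product node: children must cover disjoint sets of coordinates."""
    children: list

    def logpdf(self, x):
        return sum(c.logpdf(x) for c in self.children)

@dataclass
class Sum:
    """Sum node: mixture of children over the same coordinates; weights sum to 1."""
    weights: list
    children: list

    def logpdf(self, x):
        return logsumexp([math.log(w) + c.logpdf(x)
                          for w, c in zip(self.weights, self.children)])

# A two-component SPN over (x[0], x[1]); the parameters are made up.
spn = Sum([0.4, 0.6], [
    Product([Leaf(0, -1.0, 1.0), Leaf(1, 0.0, 1.0)]),
    Product([Leaf(0, 2.0, 0.5), Leaf(1, 1.0, 2.0)]),
])

# Midpoint-rule check on [-10, 10]^2 that the density integrates to about one.
h = 0.1
pts = [-10 + (i + 0.5) * h for i in range(200)]
total = sum(math.exp(spn.logpdf((a, b))) for a in pts for b in pts) * h * h
print(round(total, 4))
```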

[Figure: density of the fitted Sum Product Network with the training data on top; color scale 0 to 0.16]

Sum Product Transform Network

with affine transformations and Normal distributions on the leaves
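
With an invertible affine map z = A x + b at a leaf and a standard Normal on z, the change-of-variables formula gives log p(x) = log N(A x + b; 0, I) + log|det A|. The Python sketch below implements only such a leaf, in 2D; the function names are hypothetical, and how SumProductTransform.jl parameterizes A is not shown in this export.

```python
import math

def std_normal_logpdf(z):
    """Log-density of the standard Normal N(0, I) at the vector z."""
    return sum(-0.5 * zi * zi - 0.5 * math.log(2 * math.pi) for zi in z)

def affine_normal_logpdf(x, A, b):
    """log p(x) for a leaf where z = A x + b is standard Normal.

    Change of variables: log p(x) = log N(A x + b; 0, I) + log|det A|.
    A is a 2 x 2 invertible matrix given as nested lists.
    """
    z = [A[0][0] * x[0] + A[0][1] * x[1] + b[0],
         A[1][0] * x[0] + A[1][1] * x[1] + b[1]]
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    return std_normal_logpdf(z) + math.log(abs(det))

# Diagonal A: x[0] ~ N(0, 0.5^2) and x[1] ~ N(0, 2^2), independently.
A = [[2.0, 0.0], [0.0, 0.5]]
x = (0.3, -1.2)
print(affine_normal_logpdf(x, A, [0.0, 0.0]))
```

The log|det A| term is what keeps the transformed density normalized: stretching space by A must be compensated by scaling the density down by the same factor.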

[Figure: density of the fitted Sum Product Transform Network (affine leaves) with the training data on top; color scale 0 to 0.5]

Sum Product Transform Network

with nonlinear transformations on the leaves
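
A nonlinear leaf can be handled the same way as an affine one, provided the transformation stays invertible and has a tractable Jacobian. As an assumption for illustration, the Python sketch below applies the elementwise map f(u) = u + ALPHA * tanh(u) after an affine layer; this is not necessarily the nonlinearity used in the notebook, and all names are hypothetical. Because f acts elementwise, its Jacobian is diagonal, so log|det J| = log|det A| + sum_i log f'(u_i).

```python
import math

ALPHA = 0.5  # hypothetical strength of the nonlinearity; must be > -1

def f(u):
    """Elementwise bijection u + ALPHA * tanh(u), strictly increasing for ALPHA > -1."""
    return u + ALPHA * math.tanh(u)

def log_df(u):
    """log f'(u) = log(1 + ALPHA * (1 - tanh(u)**2))."""
    return math.log(1 + ALPHA * (1 - math.tanh(u) ** 2))

def nonlinear_leaf_logpdf(x, A, b):
    """log p(x) where z = f(A x + b) ~ N(0, I), with f applied elementwise.

    The Jacobian of x -> z is diag(f'(u)) @ A with u = A x + b, so
    log|det J| = log|det A| + sum_i log f'(u_i).
    """
    u = [A[0][0] * x[0] + A[0][1] * x[1] + b[0],
         A[1][0] * x[0] + A[1][1] * x[1] + b[1]]
    z = [f(ui) for ui in u]
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    logpz = sum(-0.5 * zi * zi - 0.5 * math.log(2 * math.pi) for zi in z)
    return logpz + math.log(abs(det)) + sum(log_df(ui) for ui in u)

# Made-up parameters; midpoint-rule check that the density integrates to about one.
A = [[1.2, 0.3], [-0.2, 0.8]]
b = [0.1, -0.4]
h = 0.1
pts = [-10 + (i + 0.5) * h for i in range(200)]
total = sum(math.exp(nonlinear_leaf_logpdf((p, q), A, b)) for p in pts for q in pts) * h * h
print(round(total, 4))
```

For ALPHA > -1, f'(u) = 1 + ALPHA * (1 - tanh(u)^2) stays positive, so f is invertible and its log-derivative is always defined.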

[Figure: density of the fitted Sum Product Transform Network (nonlinear leaves) with the training data on top; color scale 0 to 0.4]